{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YONnGjpAYUdU"
      },
      "source": [
        "\u003ca href=\"https://colab.research.google.com/github/google-research/language/blob/master/language/decontext/decontextualization_demo.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uXtHFqejrysP"
      },
      "source": [
        "# Decontextualization Demo\n",
        "\n",
        "This colab contains a T5 model for decontextualizing individual \n",
        "sentences. The decontextualization task is described in\n",
        "[Decontextualization: Making Sentences Stand-Alone](https://arxiv.org/abs/2102.05169).\n",
        "\n",
        "Please cite as:\n",
        "\n",
        "```\n",
        "@article{choi2021making,\n",
        "  title = {Decontextualization: Making Sentences Stand-Alone},\n",
        "  author = {Eunsol Choi and Jennimaria Palomaki and Matthew Lamm and Tom Kwiatkowski and Dipanjan Das and Michael Collins},\n",
        "  year = {2021},\n",
        "  journal = {Transactions of the Association for Computational Linguistics}\n",
        "}\n",
        "```\n",
        "\n",
        "## Input format\n",
        "The decontextualization model is trained on Wikipedia pages. The input consists\n",
        "of the page title, the (possibly empty) section titles, and a paragraph split\n",
        "into a prefix, the target sentence (to be decontextualized), and a suffix. The\n",
        "model input should have the form:\n",
        "\n",
        "```\u003cpage title\u003e [SEP] \u003csection title\u003e [SEP] \u003cpreceding sentences\u003e [SEP] \u003ctarget sentence\u003e [SEP] \u003csucceeding sentences\u003e```\n",
        "\n",
        "where any of the fields apart from `\u003ctarget sentence\u003e` may be empty, but all of\n",
        "the `[SEP]` tokens should be included.\n",
        "\n",
        "You can download the labeled train, dev, and test splits on the [paper's GitHub site](https://github.com/google-research/language/tree/master/language/decontext)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qcr6nOuM6KK1"
      },
      "source": [
        "## Load a tuned T5 model\n",
        "\n",
        "Choose which model you'd like to load, and define a prediction function.\n",
        "\n",
        "These models are tuned versions of the\n",
        "[released T5 models](https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints).\n",
        "Details of the model tuning are available in [the paper](https://arxiv.org/abs/2102.05169).\n",
        "Note that the T5-11B model is very slow on CPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "cellView": "form",
        "executionInfo": {
          "elapsed": 23094,
          "status": "ok",
          "timestamp": 1612534947157,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 300
        },
        "id": "KUGNUKuZ5Vky",
        "outputId": "9b9ec682-735f-44d3-f839-1378a06d67c9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loading SavedModel in eager mode.\n"
          ]
        }
      ],
      "source": [
        "print(\"Installing dependencies...\")\n",
        "!pip install -q tensorflow_text\n",
        "\n",
        "from os import path\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_text  # Required to run exported model.\n",
        "\n",
        "MODEL_SIZE = \"base\" #@param[\"base\", \"3B\", \"11B\"]\n",
        "\n",
        "DATASET_BUCKET = 'gs://decontext_dataset'\n",
        "\n",
        "SAVED_MODELS = {\n",
        "  \"base\": f'{DATASET_BUCKET}/t5_base/1611267950',\n",
        "  \"3B\": f'{DATASET_BUCKET}/t5_3B/1611333896',\n",
        "  \"11B\": f'{DATASET_BUCKET}/t5_11B/1605298402'\n",
        "}\n",
        "\n",
        "# Use the checkpoint selected by MODEL_SIZE above.\n",
        "SAVED_MODEL_PATH = SAVED_MODELS[MODEL_SIZE]\n",
        "# Path to the labeled dev split (not used elsewhere in this demo).\n",
        "DEV = path.join(DATASET_BUCKET, 'decontext_dev.jsonl')\n",
        "\n",
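        "def load_predict_fn(model_path):\n",
        "  \"\"\"Loads the SavedModel and returns a function from input strings to outputs.\"\"\"\n",
        "  print(\"Loading SavedModel in eager mode.\")\n",
        "  imported = tf.saved_model.load(model_path, [\"serve\"])\n",
        "  return lambda x: imported.signatures['serving_default'](\n",
        "      tf.constant(x))['outputs'].numpy()\n",
        "\n",
        "predict_fn = load_predict_fn(SAVED_MODEL_PATH)\n",
        "\n",
        "def decontextualize(model_input):\n",
        "  \"\"\"Runs the model on one formatted input string and returns the decoded output.\"\"\"\n",
        "  return predict_fn([model_input])[0].decode('utf-8')"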
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A4YkuJPs-hQe"
      },
      "source": [
        "## Try it on your own input\n",
        "\n",
        "Enter a paragraph, one sentence per list entry, along with the page title and\n",
        "any section titles.\n",
        "Then set `target_sentence_idx` to the zero-based index of the sentence you would\n",
        "like to decontextualize, and run the cell to get the model's prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "executionInfo": {
          "elapsed": 10042,
          "status": "ok",
          "timestamp": 1612534957209,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 300
        },
        "id": "rnE7nr7nH99B",
        "outputId": "2ba8f079-646b-4a53-b089-f792bcbec517"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Original sentence:         They were married on 7 November of the same year, the same day Gagarin graduated from his flight school, and they had two daughters.\n",
            "Decontextualized sentence: DONE #### Yuri Gagarin and Valentina Goryacheva were married on 7 November of the same year, the same day Gagarin graduated from his flight school, and they had two daughters.\n"
          ]
        }
      ],
      "source": [
        "paragraph = [\n",
        "  \"Gagarin was a keen sportsman and played ice hockey as a goalkeeper.\",\n",
        "  \"He was also a basketball fan and coached the Saratov Industrial Technical School team, as well as being a referee.\",\n",
        "  \"In 1957, while a cadet in flight school, Gagarin met Valentina Goryacheva at the May Day celebrations at the Red Square in Moscow.\",\n",
        "  \"She was a medical technician who had graduated from Orenburg Medical School.\",\n",
        "  \"They were married on 7 November of the same year, the same day Gagarin graduated from his flight school, and they had two daughters.\",\n",
        "  \"Yelena Yurievna Gagarina, born 1959, is an art historian who has worked as the director-general of the Moscow Kremlin Museums since 2001; and Galina Yurievna Gagarina, born 1961, is a professor of economics and the department chair at Plekhanov Russian University of Economics in Moscow.\"\n",
        "]\n",
        "\n",
        "page_title = 'Yuri Gagarin'\n",
        "section_title = 'Personal Life'  # can be empty\n",
        "target_sentence_idx = 4  # zero-based index\n",
        "\n",
        "\n",
        "if target_sentence_idx \u003e= len(paragraph) or target_sentence_idx \u003c 0:\n",
        "  raise ValueError(\n",
        "      f'Target sentence index must be in range [0, {len(paragraph) - 1}].')\n",
        "\n",
        "\n",
        "def create_input(paragraph,\n",
        "                 target_sentence_idx,\n",
        "                 page_title='',\n",
        "                 section_title=''):\n",
        "  \"\"\"Creates a single Decontextualization example input for T5.\n",
        "\n",
        "  Args:\n",
        "    paragraph: List of strings. Each string is a single sentence.\n",
        "    target_sentence_idx: Integer index into `paragraph` indicating which\n",
        "      sentence should be decontextualized.\n",
        "    page_title: Optional title string. Usually Wikipedia page title.\n",
        "    section_title: Optional title of section within page.\n",
        "\n",
        "  Returns:\n",
        "    The model input string, with the five fields joined by ' [SEP] '.\n",
        "  \"\"\"\n",
        "  prefix = ' '.join(paragraph[:target_sentence_idx])\n",
        "  target = paragraph[target_sentence_idx]\n",
        "  suffix = ' '.join(paragraph[target_sentence_idx + 1:])\n",
        "  return ' [SEP] '.join((page_title, section_title, prefix, target, suffix))\n",
        "\n",
        "d = decontextualize(\n",
        "        create_input(paragraph, target_sentence_idx, page_title,\n",
        "                     section_title))\n",
        "print(f'Original sentence:         {paragraph[target_sentence_idx]}\\n'\n",
        "      f'Decontextualized sentence: {d}')"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//third_party/py/t5:notebook",
        "kind": "private"
      },
      "name": "decontextualization-demo",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
